package cn.seecoder.ai.algorithm.ml.classification

import cn.seecoder.ai.algorithm.ml.BaseClassification
import cn.seecoder.ai.model.bo.TrainParamsBO
import org.apache.spark.ml.PipelineStage
import org.apache.spark.ml.classification.NaiveBayes

/** Classification algorithm backed by Spark ML's [[org.apache.spark.ml.classification.NaiveBayes]]. */
class NaiveBayesClassification extends BaseClassification {

  /**
   * Builds the Naive Bayes estimator that serves as this algorithm's pipeline stage.
   *
   * Only the input column names are configured. No hyperparameter setters are called, so
   * smoothing, model type, etc. stay at Spark's own defaults.
   *
   * @param trainParams training parameters supplied by the caller; not read by this implementation
   * @return an unfitted `NaiveBayes` estimator, typed as a `PipelineStage`
   */
  override def buildMachineLearningStage(trainParams: TrainParamsBO): PipelineStage =
    // NOTE(review): "indexedFeatures" is presumably produced by an earlier pipeline stage
    // (e.g. a feature indexer) — confirm against the pipeline assembly code.
    new NaiveBayes()
      .setLabelCol("label")
      .setFeaturesCol("indexedFeatures")
}
